{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "meaning-hypothetical",
   "metadata": {},
   "source": [
    "## Differential Flatness\n",
    "\n",
    "##### Richard M. Murray, 13 Nov 2021 (updated 7 Jul 2024)\n",
    "\n",
    "This notebook contains an example of using differential flatness as a mechanism for trajectory generation for a nonlinear control system. A differentially flat system is defined by creating an object using the `FlatSystem` class, which has member functions for mapping the system state and input into and out of flat coordinates. The `point_to_point()` function can be used to create a trajectory between two endpoints, written in terms of a set of basis functions defined using the `BasisFamily` class. The resulting trajectory is return as a `SystemTrajectory` object and can be evaluated using the `eval()` member function. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "historic-barbados",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the packages needed for the examples included in this notebook\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import control as ct\n",
    "import control.flatsys as fs\n",
    "import control.optimal as opt\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "309d3272",
   "metadata": {},
   "source": [
    "## Example: bicycle model\n",
    "\n",
    "To illustrate the methods of generating trajectories using differential flatness, we make use of a simple model for a vehicle navigating in the plane, known as the \"bicycle model\".  The kinematics of this vehicle can be written in terms of the contact point $(x, y)$ and the angle $\\theta$ of the vehicle with respect to the horizontal axis:\n",
    "\n",
    "<table>\n",
    "<tr>\n",
    "    <td width=\"50%\"><img src=\"https://fbswiki.org/wiki/images/5/52/Kincar.png\" width=480></td>\n",
    "    <td width=\"50%\">\n",
    "$$\n",
    "\\begin{aligned}\n",
    "  \\dot x &= \\cos\\theta\\, v \\\\\n",
    "  \\dot y &= \\sin\\theta\\, v \\\\\n",
    "  \\dot\\theta &= \\frac{v}{l} \\tan \\delta\n",
    "\\end{aligned}\n",
    "$$\n",
    "    </td>\n",
    "</tr>\n",
    "</table>\n",
    "\n",
    "The input $v$ represents the velocity of the vehicle and the input $\\delta$ represents the turning rate. The parameter $l$ is the wheelbase."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35efac80",
   "metadata": {},
   "source": [
    "We will generate trajectories for this system that correspond to a \"lane change\", in which we travel longitudinally at a fixed speed for approximately 40 meters, while moving from the right to the left by a distance of 4 meters.\n",
    "\n",
    "It will be convenient to define a function that we will use to plot the results in a uniform way. In addition to the subplot, we also change the size of the figure to make the figure wider."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "involved-riding",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the trajectory in xy coordinates\n",
    "def plot_motion(t, x, ud):\n",
    "    # Set the size of the figure\n",
    "    # plt.figure(figsize=(10, 6))\n",
    "\n",
    "    # Top plot: xy trajectory\n",
    "    plt.subplot(2, 1, 1)\n",
    "    plt.plot(x[0], x[1])\n",
    "    plt.xlabel('x [m]')\n",
    "    plt.ylabel('y [m]')\n",
    "    plt.axis([x0[0], xf[0], x0[1]-1, xf[1]+1])\n",
    "\n",
    "    # Time traces of the state and input\n",
    "    plt.subplot(2, 4, 5)\n",
    "    plt.plot(t, x[1])\n",
    "    plt.ylabel('y [m]')\n",
    "\n",
    "    plt.subplot(2, 4, 6)\n",
    "    plt.plot(t, x[2])\n",
    "    plt.ylabel('theta [rad]')\n",
    "\n",
    "    plt.subplot(2, 4, 7)\n",
    "    plt.plot(t, ud[0])\n",
    "    plt.xlabel(\"Time t [sec]\")\n",
    "    plt.ylabel(\"v [m/s]\")\n",
    "    plt.axis([0, Tf, u0[0] - 1, uf[0] + 1])\n",
    "\n",
    "    plt.subplot(2, 4, 8)\n",
    "    plt.plot(t, ud[1])\n",
    "    plt.xlabel(\"Time t [sec]\")\n",
    "    plt.ylabel(r\"$\\delta$ [rad]\")\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dc0d2bf",
   "metadata": {},
   "source": [
    "## Flat system mappings\n",
    "\n",
    "To define a flat system, we have to define the functions that take the state and compute the flat \"flag\" (flat outputs and their derivatives) and that take the flat flag and return the state and input.\n",
    "\n",
    "The `forward()` method computes the flat flag given a state and input:\n",
    "```\n",
    "  zflag = sys.forward(x, u)\n",
    "```\n",
    "The `reverse()` method computes the state and input given the flat flag:\n",
    "```\n",
    "  x, u = sys.reverse(zflag)\n",
    "```\n",
    "The flag $\\bar z$ is implemented as a list of flat outputs $z_i$ and\n",
    "their derivatives up to order $q_i$:\n",
    "\n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;`zflag[i][j]` = $z_i^{(j)}$\n",
    "\n",
    "The number of flat outputs must match the number of system inputs.\n",
    "\n",
    "In addition, a flat system is an input/output system and so we define and update function ($f(x, u)$) and output (use `None` to get the full state)."
   ]
  },
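  {
   "cell_type": "markdown",
   "id": "flat-inverse-derivation",
   "metadata": {},
   "source": [
    "The mappings implemented in the next cell can be checked by hand. The following is a short derivation sketch, assuming forward motion ($v > 0$) and using only the kinematics given earlier, with flat outputs $z_1 = x$ and $z_2 = y$:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "  \\theta &= \\operatorname{atan2}(\\dot y, \\dot x), \\qquad\n",
    "  v = \\dot x \\cos\\theta + \\dot y \\sin\\theta, \\\\\n",
    "  \\ddot x &= \\dot v \\cos\\theta - v \\dot\\theta \\sin\\theta, \\qquad\n",
    "  \\ddot y = \\dot v \\sin\\theta + v \\dot\\theta \\cos\\theta, \\\\\n",
    "  \\ddot y \\cos\\theta - \\ddot x \\sin\\theta &= v \\dot\\theta = \\frac{v^2}{l} \\tan\\delta\n",
    "  \\quad\\Longrightarrow\\quad\n",
    "  \\delta = \\operatorname{atan2}\\!\\left(\\ddot y \\cos\\theta - \\ddot x \\sin\\theta,\\; \\frac{v^2}{l}\\right).\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "The $\\dot v$ terms cancel in the third line, so the steering angle can be recovered from the second derivatives of the flat outputs without knowing $\\dot v$."
   ]
  },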
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "above-venezuela",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to take states, inputs and return the flat flag\n",
    "def bicycle_flat_forward(x, u, params={}):\n",
    "    # Get the parameter values\n",
    "    b = params.get('wheelbase', 3.)\n",
    "\n",
    "    # Create a list of arrays to store the flat output and its derivatives\n",
    "    zflag = [np.zeros(3), np.zeros(3)]\n",
    "\n",
    "    # Flat output is the x, y position of the rear wheels\n",
    "    zflag[0][0] = x[0]\n",
    "    zflag[1][0] = x[1]\n",
    "\n",
    "    # First derivatives of the flat output\n",
    "    zflag[0][1] = u[0] * np.cos(x[2])  # dx/dt\n",
    "    zflag[1][1] = u[0] * np.sin(x[2])  # dy/dt\n",
    "\n",
    "    # First derivative of the angle\n",
    "    thdot = (u[0]/b) * np.tan(u[1])\n",
    "\n",
    "    # Second derivatives of the flat output (setting vdot = 0)\n",
    "    zflag[0][2] = -u[0] * thdot * np.sin(x[2])\n",
    "    zflag[1][2] =  u[0] * thdot * np.cos(x[2])\n",
    "\n",
    "    return zflag\n",
    "\n",
    "# Function to take the flat flag and return states, inputs\n",
    "def bicycle_flat_reverse(zflag, params={}):\n",
    "    # Get the parameter values\n",
    "    b = params.get('wheelbase', 3.)\n",
    "\n",
    "    # Create a vector to store the state and inputs\n",
    "    x = np.zeros(3)\n",
    "    u = np.zeros(2)\n",
    "\n",
    "    # Given the flat variables, solve for the state\n",
    "    x[0] = zflag[0][0]  # x position\n",
    "    x[1] = zflag[1][0]  # y position\n",
    "    x[2] = np.arctan2(zflag[1][1], zflag[0][1])  # tan(theta) = ydot/xdot\n",
    "\n",
    "    # And next solve for the inputs\n",
    "    u[0] = zflag[0][1] * np.cos(x[2]) + zflag[1][1] * np.sin(x[2])\n",
    "    thdot_v = zflag[1][2] * np.cos(x[2]) - zflag[0][2] * np.sin(x[2])\n",
    "    u[1] = np.arctan2(thdot_v, u[0]**2 / b)\n",
    "\n",
    "    return x, u\n",
    "\n",
    "# Function to compute the RHS of the system dynamics\n",
    "def bicycle_update(t, x, u, params):\n",
    "    b = params.get('wheelbase', 3.)             # get parameter values\n",
    "    dx = np.array([\n",
    "        np.cos(x[2]) * u[0],\n",
    "        np.sin(x[2]) * u[0],\n",
    "        (u[0]/b) * np.tan(u[1])\n",
    "    ])\n",
    "    return dx\n",
    "\n",
    "# Return the entire state as output (instead of default flat outputs)\n",
    "def bicycle_output(t, x, u, params):\n",
    "    return x\n",
    "\n",
    "# Create differentially flat input/output system\n",
    "bicycle_flat = fs.FlatSystem(\n",
    "    bicycle_flat_forward, bicycle_flat_reverse, \n",
    "    bicycle_update, bicycle_output,\n",
    "    inputs=('v', 'delta'), outputs=('x', 'y', 'theta'),\n",
    "    states=('x', 'y', 'theta'), name='bicycle_model')\n",
    "\n",
    "print(bicycle_flat)"
   ]
  },
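  {
   "cell_type": "markdown",
   "id": "flat-roundtrip-sketch",
   "metadata": {},
   "source": [
    "A useful sanity check on a pair of flat mappings is that they invert each other: mapping a state and input to the flat flag and back should return the original values. The block below is a minimal standalone sketch of that check using only NumPy. The names `flat_forward` and `flat_reverse` are hypothetical compact restatements of the functions above (they are not part of python-control), and the check assumes $v > 0$.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def flat_forward(x, u, b=3.0):\n",
    "    # Flat flag [z, zdot, zddot] for z1 = x and z2 = y (with vdot = 0)\n",
    "    thdot = (u[0] / b) * np.tan(u[1])\n",
    "    return [\n",
    "        np.array([x[0], u[0] * np.cos(x[2]), -u[0] * thdot * np.sin(x[2])]),\n",
    "        np.array([x[1], u[0] * np.sin(x[2]), u[0] * thdot * np.cos(x[2])]),\n",
    "    ]\n",
    "\n",
    "def flat_reverse(zflag, b=3.0):\n",
    "    # Recover the heading, then the velocity and steering angle\n",
    "    theta = np.arctan2(zflag[1][1], zflag[0][1])\n",
    "    v = zflag[0][1] * np.cos(theta) + zflag[1][1] * np.sin(theta)\n",
    "    thdot_v = zflag[1][2] * np.cos(theta) - zflag[0][2] * np.sin(theta)\n",
    "    delta = np.arctan2(thdot_v, v**2 / b)\n",
    "    return np.array([zflag[0][0], zflag[1][0], theta]), np.array([v, delta])\n",
    "\n",
    "# Arbitrary test point with forward motion and a small steering angle\n",
    "x = np.array([1.0, -2.0, 0.3])\n",
    "u = np.array([10.0, 0.05])\n",
    "\n",
    "x_rt, u_rt = flat_reverse(flat_forward(x, u))\n",
    "roundtrip_ok = np.allclose(x_rt, x) and np.allclose(u_rt, u)\n",
    "print(roundtrip_ok)\n",
    "```\n",
    "\n",
    "The same check can be run against `bicycle_flat.forward()` and `bicycle_flat.reverse()` at the endpoints of a planned trajectory."
   ]
  },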
  {
   "cell_type": "markdown",
   "id": "75cb8cf6",
   "metadata": {},
   "source": [
    "## Point to point trajectory generation\n",
    "\n",
    "In addition to the flat system description, a set of basis functions\n",
    "$\\phi_i(t)$ must be chosen. The `BasisFamily` class is used to\n",
    "represent the basis functions. A polynomial basis function of the form\n",
    "$1$, $t$, $t^2$, $\\ldots$ can be computed using the `PolyFamily` class,\n",
    "which is initialized by passing the desired order of the polynomial\n",
    "basis set:\n",
    "```\n",
    "polybasis = control.flatsys.PolyFamily(N)\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feef608a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(fs.BasisFamily.__doc__)\n",
    "print(fs.PolyFamily.__doc__)\n",
    "\n",
    "# Define a set of basis functions to use for the trajectories\n",
    "poly = fs.PolyFamily(6)\n",
    "\n",
    "# Plot out the basis functions\n",
    "t = np.linspace(0, 1.5)\n",
    "for k in range(poly.N):\n",
    "    plt.plot(t, poly(k, t), label=f'k = {k}')\n",
    "    \n",
    "plt.legend()\n",
    "plt.title(\"Polynomial basis functions\")\n",
    "plt.xlabel(\"Time $t$\")\n",
    "plt.ylabel(r\"$\\psi_i(t)$\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7aacca93",
   "metadata": {},
   "source": [
    "### Approach 1: point to point solution, no cost or constraints\n",
    "\n",
    "Once the system and basis function have been defined, the\n",
    "`point_to_point()` function can be used to compute a trajectory\n",
    "between initial and final states and inputs:\n",
    "```\n",
    "traj = control.flatsys.point_to_point(sys, Tf, x0, u0, xf, uf, basis=polybasis)\n",
    "```\n",
    "The returned object has class `SystemTrajectory` and can be used\n",
    "to compute the state and input trajectory between the initial and final\n",
    "condition:\n",
    "```\n",
    "xd, ud = traj.eval(timepts)\n",
    "```\n",
    "where `timepts` is a list of times on which the trajectory should be\n",
    "evaluated (e.g., `timepts = numpy.linspace(0, Tf, M)`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "surface-piano",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the endpoints of the trajectory\n",
    "x0 = np.array([0., -2., 0.]); u0 = np.array([10., 0.])\n",
    "xf = np.array([40., 2., 0.]); uf = np.array([10., 0.])\n",
    "Tf = 4\n",
    "\n",
    "# Generate a normalized set of basis functions\n",
    "poly = fs.PolyFamily(6, Tf)\n",
    "\n",
    "# Find a trajectory between the initial condition and the final condition\n",
    "traj = fs.point_to_point(bicycle_flat, Tf, x0, u0, xf, uf, basis=poly)\n",
    "\n",
    "# Create the desired trajectory between the initial and final condition\n",
    "timepts = np.linspace(0, Tf, 500)\n",
    "xd, ud = traj.eval(timepts)\n",
    "\n",
    "# Simulation the open system dynamics with the full input\n",
    "t, y, x = ct.input_output_response(\n",
    "    bicycle_flat, timepts, ud, x0, return_x=True)\n",
    "\n",
    "# Plot the open loop system dynamics\n",
    "plt.figure(1)\n",
    "plt.suptitle(\"Open loop trajectory for unicycle lane change\")\n",
    "plot_motion(t, x, ud)\n",
    "\n",
    "# Make sure the initial and final points are correct\n",
    "print(\"x[0] = \", xd[:, 0])\n",
    "print(\"x[T] = \", xd[:, -1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82a3318a",
   "metadata": {},
   "source": [
    "### A look inside the code\n",
    "\n",
    "The code to solve this problem is inside the file [flatsys.py](https://github.com/python-control/python-control/blob/main/control/flatsys/flatsys.py) in the python-control package.  Here is what operative code inside the `point_to_point()` looks like:\n",
    "\n",
    "    #\n",
    "    # Map the initial and final conditions to flat output conditions\n",
    "    #\n",
    "    # We need to compute the output \"flag\": [z(t), z'(t), z''(t), ...]\n",
    "    # and then evaluate this at the initial and final condition.\n",
    "    #\n",
    "\n",
    "    zflag_T0 = sys.forward(x0, u0)\n",
    "    zflag_Tf = sys.forward(xf, uf)\n",
    "\n",
    "    #\n",
    "    # Compute the matrix constraints for initial and final conditions\n",
    "    #\n",
    "    # This computation depends on the basis function we are using.  It\n",
    "    # essentially amounts to evaluating the basis functions and their\n",
    "    # derivatives at the initial and final conditions.\n",
    "\n",
    "    # Compute the flags for the initial and final states\n",
    "    M_T0 = _basis_flag_matrix(sys, basis, zflag_T0, T0)\n",
    "    M_Tf = _basis_flag_matrix(sys, basis, zflag_Tf, Tf)\n",
    "\n",
    "    # Stack the initial and final matrix/flag for the point to point problem\n",
    "    M = np.vstack([M_T0, M_Tf])\n",
    "    Z = np.hstack([np.hstack(zflag_T0), np.hstack(zflag_Tf)])\n",
    "\n",
    "    #\n",
    "    # Solve for the coefficients of the flat outputs\n",
    "    #\n",
    "    # At this point, we need to solve the equation M alpha = zflag, where M\n",
    "    # is the matrix constrains for initial and final conditions and zflag =\n",
    "    # [zflag_T0; zflag_tf].\n",
    "    #\n",
    "    # If there are no constraints, then we just need to solve a linear\n",
    "    # system of equations => use least squares.  Otherwise, we have a\n",
    "    # nonlinear optimal control problem with equality constraints => use\n",
    "    # scipy.optimize.minimize().\n",
    "    #\n",
    "\n",
    "    # Start by solving the least squares problem\n",
    "    alpha, residuals, rank, s = np.linalg.lstsq(M, Z, rcond=None)"
   ]
  },
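  {
   "cell_type": "markdown",
   "id": "lstsq-sketch",
   "metadata": {},
   "source": [
    "The least squares step quoted above can be illustrated on a smaller problem. The block below is a standalone sketch for a single flat output (the lateral position $y$ of the lane change), assuming a polynomial basis of the form $(t/T)^k$. The helpers `flag_rows` and `y_of_t` are hypothetical names introduced for illustration; they are not the library's `_basis_flag_matrix()`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from math import factorial\n",
    "\n",
    "def flag_rows(t, N, nderiv, T):\n",
    "    # Row j holds the j-th time derivative of (t/T)^k for k = 0..N-1\n",
    "    M = np.zeros((nderiv, N))\n",
    "    for j in range(nderiv):\n",
    "        for k in range(j, N):\n",
    "            coeff = factorial(k) / factorial(k - j)\n",
    "            M[j, k] = coeff * (t / T) ** (k - j) / T**j\n",
    "    return M\n",
    "\n",
    "Tf, N = 4.0, 6\n",
    "z0 = np.array([-2.0, 0.0, 0.0])   # y, ydot, yddot at t = 0\n",
    "zf = np.array([2.0, 0.0, 0.0])    # y, ydot, yddot at t = Tf\n",
    "\n",
    "# Stack the endpoint conditions and solve M alpha = Z\n",
    "M = np.vstack([flag_rows(0.0, N, 3, Tf), flag_rows(Tf, N, 3, Tf)])\n",
    "Z = np.hstack([z0, zf])\n",
    "alpha, residuals, rank, s = np.linalg.lstsq(M, Z, rcond=None)\n",
    "\n",
    "def y_of_t(t):\n",
    "    # Evaluate the resulting polynomial at time t\n",
    "    return sum(alpha[k] * (t / Tf) ** k for k in range(N))\n",
    "\n",
    "print(y_of_t(0.0), y_of_t(Tf / 2), y_of_t(Tf))\n",
    "```\n",
    "\n",
    "With six coefficients and six endpoint conditions the matrix is square and full rank, so least squares returns the unique interpolating polynomial. Using more basis functions than conditions leaves free coefficients, which is what the cost-based approaches exploit."
   ]
  },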
  {
   "cell_type": "markdown",
   "id": "f0397b3e",
   "metadata": {},
   "source": [
    "### Approach #2: add cost function to make lane change quicker"
   ]
  },
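  {
   "cell_type": "markdown",
   "id": "quadratic-cost-sketch",
   "metadata": {},
   "source": [
    "The cost used in this approach penalizes deviation from the final state and input. As a rough illustration, the block below evaluates a quadratic form $(x - x_f)^T Q (x - x_f) + (u - u_f)^T R (u - u_f)$ at a single point, using the same weights as the next cell. This is a sketch under the assumption that the running cost has this form; `quadratic_cost_value` is a hypothetical helper, not a python-control function.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "Q = np.diag([0, 0.1, 0])    # penalize lateral error only\n",
    "R = np.diag([0.1, 1])       # penalize velocity and steering deviations\n",
    "xf = np.array([40.0, 2.0, 0.0])\n",
    "uf = np.array([10.0, 0.0])\n",
    "\n",
    "def quadratic_cost_value(x, u, Q, R, x0, u0):\n",
    "    # Quadratic penalty on the deviation from the reference point (x0, u0)\n",
    "    dx = np.asarray(x, dtype=float) - x0\n",
    "    du = np.asarray(u, dtype=float) - u0\n",
    "    return float(dx @ Q @ dx + du @ R @ du)\n",
    "\n",
    "# Cost rate at the start of the lane change (4 m lateral offset)\n",
    "cost_start = quadratic_cost_value([0.0, -2.0, 0.0], [10.0, 0.0], Q, R, xf, uf)\n",
    "print(cost_start)\n",
    "```\n",
    "\n",
    "Only the 4 m lateral offset contributes here ($0.1 \\cdot 4^2 = 1.6$), since the longitudinal position has zero weight. This is why the optimizer is pushed to complete the lateral move early."
   ]
  },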
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "appreciated-baghdad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define timepoints for evaluation plus basis function to use\n",
    "timepts = np.linspace(0, Tf, 20)\n",
    "basis = fs.PolyFamily(12, Tf)\n",
    "\n",
    "# Define the cost function (penalize lateral error and steering)\n",
    "traj_cost = opt.quadratic_cost(\n",
    "    bicycle_flat, np.diag([0, 0.1, 0]), np.diag([0.1, 1]), x0=xf, u0=uf)\n",
    "\n",
    "# Solve for an optimal solution\n",
    "start_time = time.process_time()\n",
    "traj = fs.point_to_point(\n",
    "    bicycle_flat, timepts, x0, u0, xf, uf, cost=traj_cost, basis=basis,\n",
    ")\n",
    "print(\"* Total time = %5g seconds\\n\" % (time.process_time() - start_time))\n",
    "\n",
    "xd, ud = traj.eval(timepts)\n",
    "\n",
    "plt.figure(2)\n",
    "plt.suptitle(\"Lane change with lateral error + steering penalties\")\n",
    "plot_motion(timepts, xd, ud);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff7363ca",
   "metadata": {},
   "source": [
    "Note that the solution has a very large steering angle (0.2 rad = ~12 degrees)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c533abe",
   "metadata": {},
   "source": [
    "### Approach #3: optimal cost with trajectory constraints\n",
    "\n",
    "To get a smaller steering angle, we add constraints on the inputs."
   ]
  },
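  {
   "cell_type": "markdown",
   "id": "input-range-check-sketch",
   "metadata": {},
   "source": [
    "After solving a constrained problem it is worth confirming that the computed inputs stay inside the requested bounds. The block below is a small sketch of such a check, assuming the input trajectory is an array with one row per input, as returned by `traj.eval()`. The helper `inputs_within_range` and the sample arrays are hypothetical and made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def inputs_within_range(ud, lb, ub, tol=1e-6):\n",
    "    # Compare each input row against its lower and upper bound\n",
    "    lb = np.asarray(lb, dtype=float).reshape(-1, 1)\n",
    "    ub = np.asarray(ub, dtype=float).reshape(-1, 1)\n",
    "    return bool(np.all(ud >= lb - tol) and np.all(ud <= ub + tol))\n",
    "\n",
    "# Rows are v [m/s] and delta [rad]; columns are time points\n",
    "ud_ok = np.array([[10.0, 10.5, 9.8], [0.0, 0.08, -0.05]])\n",
    "ud_bad = np.array([[10.0, 10.5, 9.8], [0.0, 0.20, -0.05]])\n",
    "\n",
    "check_ok = inputs_within_range(ud_ok, [8, -0.1], [12, 0.1])\n",
    "check_bad = inputs_within_range(ud_bad, [8, -0.1], [12, 0.1])\n",
    "print(check_ok, check_bad)\n",
    "```\n",
    "\n",
    "A small tolerance is included because a numerical optimizer typically satisfies constraints only to within its solver tolerance."
   ]
  },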
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "stable-network",
   "metadata": {},
   "outputs": [],
   "source": [
    "constraints = [\n",
    "    opt.input_range_constraint(bicycle_flat, [8, -0.1], [12, 0.1]) ]\n",
    "\n",
    "# Solve for an optimal solution\n",
    "traj = fs.point_to_point(\n",
    "    bicycle_flat, timepts, x0, u0, xf, uf, cost=traj_cost,\n",
    "    trajectory_constraints=constraints, basis=basis,\n",
    ")\n",
    "xd, ud = traj.eval(timepts)\n",
    "\n",
    "plt.figure(3)\n",
    "plt.suptitle(\"Lane change with penalty + steering constraints\")\n",
    "plot_motion(timepts, xd, ud)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "677750b0",
   "metadata": {},
   "source": [
    "## Ideas to explore\n",
    "* Change the number of basis functions\n",
    "* Change the number of time points\n",
    "* Change the type of basis functions: BezierFamily"
   ]
  },
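  {
   "cell_type": "markdown",
   "id": "bernstein-sketch",
   "metadata": {},
   "source": [
    "Bezier curves are built from Bernstein polynomials. The block below is a standalone sketch of that basis, assuming $N$ basis functions of degree $N - 1$ on the interval $[0, T]$. The function `bernstein` is a hypothetical helper written for illustration and may differ in detail from the `BezierFamily` implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from math import comb\n",
    "\n",
    "def bernstein(k, t, N, T=1.0):\n",
    "    # k-th Bernstein polynomial of degree N - 1, evaluated at s = t/T\n",
    "    n = N - 1\n",
    "    s = np.asarray(t, dtype=float) / T\n",
    "    return comb(n, k) * s**k * (1 - s) ** (n - k)\n",
    "\n",
    "# The basis functions sum to one at every point in [0, T]\n",
    "t = np.linspace(0, 2, 5)\n",
    "total = sum(bernstein(k, t, N=6, T=2.0) for k in range(6))\n",
    "partition_ok = np.allclose(total, 1.0)\n",
    "print(partition_ok)\n",
    "```\n",
    "\n",
    "Unlike the monomials $1, t, t^2, \\ldots$, each Bernstein polynomial stays between 0 and 1 on the interval, which tends to give a better conditioned problem when the number of basis functions grows."
   ]
  },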
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1622bccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a set of basis functions to use for the trajectories\n",
    "poly = fs.BezierFamily(6, 2)\n",
    "\n",
    "# Plot out the basis functions\n",
    "t = np.linspace(0, 2)\n",
    "for k in range(poly.N):\n",
    "    plt.plot(t, poly(k, t), label=f'k = {k}')\n",
    "    \n",
    "plt.legend()\n",
    "plt.title(\"Bezier basis functions\")\n",
    "plt.xlabel(\"Time $t$\")\n",
    "plt.ylabel(r\"$\\psi_i(t)$\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc566fb2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
